
[ROCm] Fix allreduce + RMSNorm fusion pattern matching #41767

Open
rbrugaro-amd wants to merge 3 commits into vllm-project:main from rbrugaro-amd:rbrugaro/fix-allreduce-rms-fusion

Contributor

@rbrugaro-amd rbrugaro-amd commented May 6, 2026

Summary

Fixes two issues that broke the allreduce + RMSNorm fusion pass introduced in #37646. Both were caused by subsequent refactoring in #36823.

  1. torch.empty_like → torch.zeros_like in AiterAllreduceFusedRMSNormPattern._replacement (allreduce_rms_fusion.py):
    The fused allreduce+rmsnorm kernel always adds res_inp. With empty_like, res_inp holds uninitialized values, so that add corrupts the output. Changed to zeros_like so the add is a no-op when the residual is freshly created.

  2. Conditional variance_size_override argument in RMSNorm.forward_native (layernorm.py):
    After the IR refactoring, ir.ops.rms_norm and ir.ops.fused_add_rms_norm.maybe_inplace were unconditionally passed self.variance_size_override (even when None). This produced 4-argument calls in the FX graph, but the fusion patterns expect 3 arguments. The mismatch prevented pattern matching entirely. Fixed by conditionally unpacking variance_size_override only when it is not None.
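Both fixes can be illustrated with a minimal pure-Python sketch. It is not the vLLM or aiter code: `rms_norm`, `fused_allreduce_add_rms_norm`, and `build_norm_args` are hypothetical stand-ins, plain lists replace tensors, and the "garbage" residual only imitates what uninitialized memory might contain.

```python
import math


def rms_norm(x, weight, eps):
    """Reference RMSNorm over a flat list of floats."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def fused_allreduce_add_rms_norm(reduced, residual, weight, eps):
    """Toy stand-in for a fused kernel that ALWAYS adds the residual."""
    summed = [a + r for a, r in zip(reduced, residual)]
    return rms_norm(summed, weight, eps), summed


def build_norm_args(x, weight, eps, variance_size_override=None):
    """Append the override only when it is set, so the default call has 3 arguments."""
    extra = (variance_size_override,) if variance_size_override is not None else ()
    return (x, weight, eps, *extra)


reduced = [1.0, -2.0, 3.0, 0.5]   # pretend this is the allreduce result
weight = [1.0, 1.0, 1.0, 1.0]
eps = 1e-6

# Fix 1: a zero residual makes the unconditional add a no-op.
zeros = [0.0] * len(reduced)
fused_out, new_residual = fused_allreduce_add_rms_norm(reduced, zeros, weight, eps)
plain_out = rms_norm(reduced, weight, eps)

# An arbitrary residual (imitating uninitialized memory) changes the result.
garbage = [1e30, -7.0, 42.0, 3.0]
bad_out, _ = fused_allreduce_add_rms_norm(reduced, garbage, weight, eps)

# Fix 2: the call arity depends on whether the override is set.
default_call = build_norm_args("x", weight, eps)          # 3 arguments
override_call = build_norm_args("x", weight, eps, 128)    # 4 arguments
```

With the zero residual, the fused result equals the plain RMSNorm of the reduced input. The conditional unpacking keeps the common case at three arguments, which is the shape the fusion patterns described above expect.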

Testing

Tested with Kimi-K2-Thinking-MXFP4 on 4x MI355X (TP=4)

  • vllm: 0.20.1rc1.dev153+gcfd2573f2 (base commit cfd2573f2)
  • aiter: amd-aiter 0.1.12.post2.dev126+g033d8b9db
  • torch: 2.10.0+git8514f05

Fusion pass results (confirmed via VLLM_DEBUG_DUMP_PATH graph dumps and custom logging):

  • all_reduce_fusion_pass: 244 pattern matches across 2 compile ranges (122 per range)
  • mla_dual_rms_norm_fusion_pass: 183 matches
  • Graph reduced from ~4491 → ~4367 nodes after fusion/cleanup passes
  • fused_add_rms_norm=aiter implementation selected

Signed-off-by: Rita Brugarolas Brufau <rita.brugarolasbrufau@amd.com>
@rbrugaro-amd rbrugaro-amd changed the title from "[ROCm] Fix allreduce + RMSNorm fusion pattern matching" to "[WIP][ROCm] Fix allreduce + RMSNorm fusion pattern matching" May 6, 2026
@mergify mergify Bot added the rocm Related to AMD ROCm label May 6, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD May 6, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the allreduce_rms_fusion pass to initialize the residual tensor with zeros instead of uninitialized memory. Additionally, it modifies the RMSNorm forward pass in layernorm.py to conditionally pass the variance_size_override argument only when it is not None. I have no feedback to provide as there were no review comments.

@attila-dusnoki-htec

I tested this on base commit 577b9623e6f8801698d411f4b04269326f5afbe2 plus the change from PR #40392, using the kimi-k2.5-mxfp4 model.

Without this patch and "fuse_allreduce_rms": false:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9554|±  |0.052|
|     |       |strict-match    |     5|exact_match|↑  |0.9417|±  |0.0067|

Without this patch and "fuse_allreduce_rms": true:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.008|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.004|±  |0.0040|

With this patch and "fuse_allreduce_rms": true:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.952|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.952|±  |0.0135|

@rbrugaro-amd rbrugaro-amd marked this pull request as ready for review May 7, 2026 15:50

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

self.weight.data if self.pass_weight else None,
self.variance_epsilon,
self.variance_size_override,
*(
Collaborator

@ProExpertProg ProExpertProg May 7, 2026

Let's just fix the patterns in the pass instead
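One way to picture a pattern-side fix is the toy matcher below. It is only a sketch under loud assumptions: `ANY`, `matches`, and the tuple-based "patterns" are hypothetical and are not vLLM's pattern-matcher API, and the thread does not spell out which pattern change is intended.

```python
ANY = object()  # hypothetical wildcard standing in for a pattern input


def matches(pattern, call):
    """Toy matcher: arity must agree, and non-wildcard entries must be equal."""
    if len(pattern) != len(call):
        return False
    return all(p is ANY or p == c for p, c in zip(pattern, call))


# Call as traced after the refactor: the override is always passed, even when None.
traced_call = ("x", "weight", 1e-6, None)
# Call where a real override is set; fusion should not apply to this one.
override_call = ("x", "weight", 1e-6, 128)

old_pattern = (ANY, ANY, ANY)         # 3-argument pattern: arity mismatch
new_pattern = (ANY, ANY, ANY, None)   # pattern that also spells out override=None

old_hits = matches(old_pattern, traced_call)
new_hits = matches(new_pattern, traced_call)
override_hits = matches(new_pattern, override_call)
```

In this toy model, the call site keeps passing four arguments and only the pattern changes, so the layer code stays untouched while calls with a real override still fall outside the fused path.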

@rbrugaro-amd rbrugaro-amd changed the title from "[WIP][ROCm] Fix allreduce + RMSNorm fusion pattern matching" to "[ROCm] Fix allreduce + RMSNorm fusion pattern matching" May 7, 2026